#!/usr/bin/env python3
# --------------------------------------------------------------------------------------
# Imports & Config
# --------------------------------------------------------------------------------------
import argparse
import json
import os
import re
import sys
from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# --------------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------------
# Both are populated by main() before any evaluation runs.
model = None
tokenizer = None

# --------------------------------------------------------------------------------------
# Single system prompt (no persona)
# --------------------------------------------------------------------------------------

# System turn sent with every request; asks for a two-line "Answer:/Reason:" reply
# that run_gmo_evaluation's score regex parses.
SYSTEM_PROMPT = "\n".join(
    [
        "You are an expert digital advertisement analyst. "
        "Given the ad description below, predict the Click-Through-Rate percentile it will achieve (0-100).",
        "",
        "Return your response in exactly two lines:",
        "Answer: <0-100>",
        "Reason: <brief justification>",
    ]
)

# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

def verbalize(full_prompt: str, model, tokenizer, args) -> str:
    """Generate one chat completion for ``full_prompt`` and return the decoded reply.

    The conversation is ``SYSTEM_PROMPT`` as the system turn and ``full_prompt``
    (prefixed with the Qwen3 ``/no_think`` soft switch) as the user turn.
    ``enable_thinking=False`` is also forwarded to the chat template.

    Args:
        full_prompt: User-turn text (ICL examples plus the final query).
        model: Loaded causal LM; must expose ``.device`` and ``.generate``.
        tokenizer: Matching tokenizer that provides a chat template.
        args: CLI namespace. Not read here; kept for call-site compatibility.

    Returns:
        The newly generated text (prompt tokens and special tokens removed),
        stripped of surrounding whitespace.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        # Keep the soft switch separated from the prompt text: glued onto the
        # first word ("/no_thinkGiven ...") it may not be recognised as the directive.
        {"role": "user", "content": "/no_think\n" + full_prompt},
    ]
    # return_dict=True yields both input_ids and attention_mask, so generate()
    # does not have to guess the mask from pad/eos token ids.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        enable_thinking=False,  # forwarded to the chat template
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1200,
            temperature=0.85,
            use_cache=True,
            do_sample=True,
            min_p=0.1,
        )

    # Decode only the tokens produced after the prompt.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``gpu_id``,
        ``dataset_paths`` (required), ``max_examples`` and ``tweet_eval``.
    """
    option_specs = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "gmo_results", "help": "Directory to write JSON results."}),
        ("--gpu_id", {"type": int, "default": 0, "help": "GPU ID to use for inference."}),
        (
            "--dataset_paths",
            {"type": str, "required": True, "help": "Comma-separated list of *.jsonl datasets to evaluate."},
        ),
        (
            "--max_examples",
            {
                "type": int,
                "default": None,
                "help": "(Optional) truncate dataset to this many examples – useful for quick smoke tests.",
            },
        ),
        # Retained only because run.sh still passes it; drop once the runner stops doing so.
        ("--tweet_eval", {"action": "store_true", "help": "Legacy flag to trigger this evaluation path."}),
    )

    cli = argparse.ArgumentParser(description="Run GMO evaluation over a slice of a campaign dataset.")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

# --------------------------------------------------------------------------------------
# Main evaluation logic
# --------------------------------------------------------------------------------------

def main() -> None:
    """Parse CLI args, load the model/tokenizer into module globals, and run the evaluation."""
    global model, tokenizer
    # Local import: only needed here, and keeps the file-level import block unchanged.
    from transformers import BitsAndBytesConfig

    args = parse_args()

    # -----------------------------------------------------------------------------
    # Load Model once at the start
    # -----------------------------------------------------------------------------
    model_name = "Qwen/Qwen3-32B"  # change to meta-llama/Llama-3.3-70B-Instruct for Llama
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): args.gpu_id is parsed but not used; device_map="auto" places the
    # model on whatever GPUs are visible -- presumably the launcher restricts them
    # (e.g. via CUDA_VISIBLE_DEVICES). Confirm.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        # Passing load_in_4bit directly to from_pretrained is deprecated; the
        # equivalent, supported form is a BitsAndBytesConfig.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    # The script performs a single task, so dispatch to it directly.
    run_gmo_evaluation(args)


# Matches e.g. "Answer: 73" / "answer - 42.5"; compiled once instead of per record.
_SCORE_PATTERN = re.compile(r"(?i)answer[^0-9]{0,10}(\d{1,3}(?:\.\d+)?)")


def _load_records(path: str, max_examples=None) -> List[Dict]:
    """Parse a JSONL file into a list of JSON objects.

    Blank lines are ignored; lines that are not valid JSON or whose value is not a
    JSON object are skipped and reported in one warning. When ``max_examples`` is
    truthy, reading stops once that many records have been collected.
    """
    records: List[Dict] = []
    skipped = 0
    with open(path, "r", encoding="utf-8") as f_in:
        for line in f_in:
            if max_examples and len(records) >= max_examples:
                break
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Downstream code calls rec.get(...), so only JSON objects are usable.
            if not isinstance(obj, dict):
                skipped += 1
                continue
            records.append(obj)
    if skipped:
        print(f"[WARNING] Skipped {skipped} malformed line(s) in {path}", file=sys.stderr)
    return records


def _resolve_slice(n_records: int, start=None, end=None):
    """Return inclusive ``(start, end)`` bounds clamped to ``n_records`` items.

    ``start`` defaults to 0 and is floored at 0. ``end`` defaults to the last index
    and is clamped to ``[-1, n_records - 1]``, so a negative ``end`` selects nothing
    instead of triggering Python's from-the-end indexing.
    """
    slice_start = 0 if start is None else max(0, start)
    slice_end = n_records - 1 if end is None else end
    slice_end = max(-1, min(slice_end, n_records - 1))
    return slice_start, slice_end


def _extract_score(text: str):
    """Return the number after "Answer" in ``text`` clamped to [0, 100], or None if absent."""
    match = _SCORE_PATTERN.search(text)
    if match is None:
        return None
    return max(0.0, min(100.0, float(match.group(1))))


def _write_json_atomic(path: str, payload) -> None:
    """Write ``payload`` as indented JSON to ``path`` via a temp file and ``os.replace``.

    A run killed mid-write leaves the previous complete file (plus a stray
    ``.tmp``) rather than a truncated JSON file.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f_out:
        json.dump(payload, f_out, indent=2)
    os.replace(tmp_path, path)


def run_gmo_evaluation(args):
    """Evaluate every dataset listed in ``args.dataset_paths`` with the loaded model.

    Each dataset is a JSONL file of ``{"prompt": ..., "response": ...}`` objects,
    where ``prompt`` holds the ICL examples plus the final query and ``response``
    (optional) is the ground-truth CTR percentile. For each selected record the
    model is queried via ``verbalize``, a score is parsed from its ``Answer:``
    line, and the accumulated results are saved to
    ``<output_dir>/gmo_results_<dataset>_<start>_<end>.json`` after every example.

    Reads the module-level ``model``/``tokenizer`` set up by ``main``. Exits with
    status 1 if no dataset path is supplied.
    """
    global model, tokenizer
    dset_paths = [p.strip() for p in (args.dataset_paths or "").split(",") if p.strip()]
    if not dset_paths:
        # parse_args marks the flag required, so this only triggers for values like ",,".
        print("[ERROR] --dataset_paths is a required argument.", file=sys.stderr)
        sys.exit(1)

    overall_out_dir = args.output_dir
    os.makedirs(overall_out_dir, exist_ok=True)

    for dpath in dset_paths:
        dataset_name = os.path.basename(dpath)
        print(f"\n[INFO] Processing dataset: {dataset_name}")

        records = _load_records(dpath, args.max_examples)
        n_total = len(records)

        # --- Apply slicing if --start/--end are provided ---
        slice_start, slice_end = _resolve_slice(
            n_total, getattr(args, "start", None), getattr(args, "end", None)
        )
        if slice_start > 0 or slice_end < n_total - 1:
            records = records[slice_start : slice_end + 1]
            print(f"[INFO] Processing slice {slice_start}-{slice_end} (n={len(records)}) of {dataset_name}")
        else:
            print(f"[INFO] Processing full dataset {dataset_name} (n={len(records)})")

        # Slice bounds in the filename keep parallel slice runs from clashing.
        out_path = os.path.join(overall_out_dir, f"gmo_results_{dataset_name}_{slice_start}_{slice_end}.json")

        all_results = []
        for rec in tqdm(records, desc=dataset_name):
            ad_prompt = rec.get("prompt", "")
            gt_resp = rec.get("response", None)  # ground-truth CTR percentile if provided

            resp_text = verbalize(ad_prompt, model, tokenizer, args)

            all_results.append({
                "prompt": ad_prompt,
                "ground_truth": gt_resp,
                "response": resp_text,
                "predicted_score": _extract_score(resp_text),
            })

            # Incremental save after every example so an interrupted run keeps its
            # progress; best-effort, the final save below is authoritative.
            try:
                _write_json_atomic(out_path, all_results)
            except OSError as exc:
                print(f"[WARNING] Incremental save failed: {exc}", file=sys.stderr)

        # Final save also guarantees the file exists when the slice is empty.
        _write_json_atomic(out_path, all_results)

        print(f"[INFO] Completed processing {dataset_name} slice {slice_start}-{slice_end}. Results saved to {out_path}")

    print("\n[INFO] GMO ad evaluation complete.")

if __name__ == "__main__":
    # Script entry point: parse CLI args, load the model, run the evaluation.
    main()


            

            

        
        
        

        
